{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "37dfe75a",
   "metadata": {},
   "source": [
    "# 1-3, Example Workflow for Modeling Text Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87488d10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import datetime\n",
    "\n",
    "# Print a separator bar followed by the current time\n",
    "def printbar():\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "\n",
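    "# On macOS, running PyTorch and matplotlib together in Jupyter can fail with a duplicate OpenMP runtime error; setting this environment variable works around it\n",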
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\" \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76dfe3a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torchtext==0.11.0\n",
    "!pip install torchkeras==3.2.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95288a6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torchtext \n",
    "import torchkeras \n",
    "print(\"torch.__version__ = \", torch.__version__)\n",
    "print(\"torchtext.__version__ = \", torchtext.__version__) \n",
    "print(\"torchkeras.__version__ = \", torchkeras.__version__) \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f27c47e0",
   "metadata": {},
   "source": [
    "```\n",
    "torch.__version__ =  1.10.0\n",
    "torchtext.__version__ =  0.11.0\n",
    "torchkeras.__version__ =  3.2.3\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47c4202e",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "<font color=\"red\">\n",
    " \n",
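    "Reply to the WeChat official account **算法美食屋** with the keyword **pytorch** to get the Baidu Netdisk download link for this project's source code and datasets.\n",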
    "    \n",
    "</font> \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be150a8a",
   "metadata": {},
   "source": [
    "### 1. Prepare the Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c21091b",
   "metadata": {},
   "source": [
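    "The goal of the imdb dataset is to predict the sentiment label of a movie review from the review's text.\n",
    "\n",
    "The training set contains 20,000 movie reviews and the test set contains 5,000; in each set, half of the reviews are positive and half are negative.\n",
    "\n",
    "Preprocessing text data is fairly tedious. It involves tokenizing the text, building a vocabulary, converting tokens to indices, padding sequences, building the data pipeline, and so on.\n"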
   ]
  },
  {
   "cell_type": "markdown",
   "id": "758a8aa6",
   "metadata": {},
   "source": [
    "In PyTorch, text data can be preprocessed with the vocabulary utilities in torchtext together with a custom Dataset.\n",
    "\n",
    "The steps are demonstrated below.\n"
   ]
  },
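  {
   "cell_type": "markdown",
   "id": "sketch-vocab-md",
   "metadata": {},
   "source": [
    "Before turning to torchtext, here is a minimal pure-Python sketch of what building a vocabulary and encoding tokens amount to. It is written for illustration only and assumes that `build_vocab_from_iterator` keeps tokens that occur at least `min_freq` times, orders them by descending frequency, and puts the `specials` first when `special_first=True`. The helpers `build_vocab_sketch` and `encode_sketch` and the toy documents are hypothetical, and the order of equally frequent tokens may differ from torchtext's."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-vocab-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Hypothetical approximation of build_vocab_from_iterator(..., min_freq, specials, special_first=True).\n",
    "# Not torchtext's implementation; ties between equally frequent tokens may be ordered differently.\n",
    "def build_vocab_sketch(token_lists, min_freq, specials):\n",
    "    counter = Counter(tok for tokens in token_lists for tok in tokens)\n",
    "    # keep tokens seen at least min_freq times, most frequent first\n",
    "    kept = [tok for tok, cnt in counter.most_common() if cnt >= min_freq and tok not in specials]\n",
    "    itos = list(specials) + kept                       # index -> string; specials take the first indices\n",
    "    stoi = {tok: idx for idx, tok in enumerate(itos)}  # string -> index\n",
    "    return itos, stoi\n",
    "\n",
    "def encode_sketch(tokens, stoi, unk_idx=1):\n",
    "    # unknown tokens fall back to unk_idx, mirroring vocab.set_default_index(UNK_IDX)\n",
    "    return [stoi.get(tok, unk_idx) for tok in tokens]\n",
    "\n",
    "toy_docs = [['a', 'good', 'movie'], ['a', 'bad', 'movie'], ['a', 'good', 'plot']]\n",
    "toy_itos, toy_stoi = build_vocab_sketch(toy_docs, min_freq=2, specials=['<pad>', '<unk>'])\n",
    "print(toy_itos)\n",
    "print(encode_sketch(['a', 'great', 'movie'], toy_stoi))"
   ]
  },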
  {
   "cell_type": "markdown",
   "id": "9132822a",
   "metadata": {},
   "source": [
    "![](./data/电影评论.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eef3f86",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd \n",
    "import torch \n",
    "from torchtext.data.utils import get_tokenizer\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "\n",
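    "MIN_FREQ = 30      # keep only tokens that occur at least 30 times in the training set\n",
    "MAX_LEN = 200      # pad or truncate every sample to 200 tokens\n",
    "BATCH_SIZE = 50    # matches the batch_size used by the DataLoaders below\n",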
    "\n",
    "\n",
    "dftrain = pd.read_csv(\"./eat_pytorch_datasets/imdb/train.tsv\",sep=\"\\t\",header = None,names = [\"label\",\"text\"])\n",
    "dfval = pd.read_csv(\"./eat_pytorch_datasets/imdb/test.tsv\",sep=\"\\t\",header = None,names = [\"label\",\"text\"])\n",
    "\n",
    "\n",
    "# 1. Tokenize the text\n",
    "tokenizer = get_tokenizer('basic_english')\n",
    "\n",
    "\n",
    "# 2. Build the vocabulary\n",
    "PAD_IDX,UNK_IDX = 0,1\n",
    "special_symbols = ['<pad>','<unk>']\n",
    "\n",
    "def yield_tokens(dfdata):\n",
    "    for text in dfdata[\"text\"]:\n",
    "        yield tokenizer(text)\n",
    "        \n",
    "\n",
    "        \n",
    "vocab = build_vocab_from_iterator(\n",
    "    yield_tokens(dftrain),\n",
    "    min_freq = MIN_FREQ,\n",
    "    specials=special_symbols,\n",
    "    special_first=True)\n",
    "\n",
    "vocab.set_default_index(UNK_IDX)\n",
    "vocab_size = len(vocab)\n",
    "print(\"vocab_size =\"+str(vocab_size)) \n",
    "\n",
    "# Show the first 20 tokens in the vocabulary\n",
    "#itos: index to string\n",
    "#stoi: string to index\n",
    "print(\"vocab.get_itos():\\n\",vocab.get_itos()[:20])\n",
    "print(\"vocab.get_stoi()['<pad>']:\\n\",vocab.get_stoi()['<pad>'])\n",
    "\n",
    "\n",
    "# 3. Pad with pad_value (or truncate) so every sequence has exactly max_length tokens\n",
    "def pad(seq,max_length,pad_value=0):\n",
    "    result = seq+[pad_value]*max_length\n",
    "    return result[:max_length]\n",
    "\n",
    "\n",
    "# 4. Encode: text -> tokens -> vocabulary indices -> padded index list\n",
    "def text_pipeline(text):\n",
    "    words = tokenizer(text)\n",
    "    tokens = vocab(words)\n",
    "    result = pad(tokens,MAX_LEN,PAD_IDX)\n",
    "    return result \n",
    "\n",
    "print(text_pipeline(\"this is an example!\")) \n"
   ]
  },
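  {
   "cell_type": "markdown",
   "id": "sketch-pad-md",
   "metadata": {},
   "source": [
    "A quick standalone check of the padding step: since `pad` appends `max_length` copies of `pad_value` and then slices to `max_length`, short sequences are right-padded and long ones are truncated to their first `max_length` tokens. The copy below is renamed `pad_sketch` (a hypothetical name) so that it does not overwrite the notebook's own `pad`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-pad-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standalone copy of pad() from the cell above, renamed so the original is not shadowed.\n",
    "def pad_sketch(seq, max_length, pad_value=0):\n",
    "    # append max_length pad values, then keep exactly max_length items\n",
    "    return (seq + [pad_value] * max_length)[:max_length]\n",
    "\n",
    "print(pad_sketch([5, 6, 7], 5))        # [5, 6, 7, 0, 0]: short input is right-padded\n",
    "print(pad_sketch(list(range(10)), 5))  # [0, 1, 2, 3, 4]: long input is truncated"
   ]
  },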
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8338e8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5. Build the data pipeline\n",
    "from torch.utils.data import Dataset,DataLoader\n",
    "\n",
    "class ImdbDataset(Dataset):\n",
    "    def __init__(self,df):\n",
    "        self.df = df\n",
    "    def __len__(self):\n",
    "        return len(self.df)\n",
    "    def __getitem__(self,index):\n",
    "        text = self.df[\"text\"].iloc[index]\n",
    "        label = torch.tensor([self.df[\"label\"].iloc[index]]).float()\n",
    "        tokens = torch.tensor(text_pipeline(text)).int() \n",
    "        return tokens,label\n",
    "    \n",
    "ds_train = ImdbDataset(dftrain)\n",
    "ds_val = ImdbDataset(dfval)\n"
   ]
  },
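  {
   "cell_type": "markdown",
   "id": "sketch-dataset-md",
   "metadata": {},
   "source": [
    "To see what a batch from this pipeline looks like without the IMDB files, the sketch below runs a hypothetical three-sample `ToyDataset`, which returns items in the same layout as `ImdbDataset` (an int32 tensor of token ids and a float label of shape `[1]`), through a `DataLoader`. The default collate function stacks the items, so features come out as `(batch_size, seq_len)` and labels as `(batch_size, 1)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-dataset-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "# Hypothetical toy dataset with the same item layout as ImdbDataset\n",
    "class ToyDataset(Dataset):\n",
    "    def __init__(self, sequences, labels):\n",
    "        self.sequences, self.labels = sequences, labels\n",
    "    def __len__(self):\n",
    "        return len(self.sequences)\n",
    "    def __getitem__(self, index):\n",
    "        tokens = torch.tensor(self.sequences[index]).int()  # already padded token ids\n",
    "        label = torch.tensor([self.labels[index]]).float()  # shape [1]\n",
    "        return tokens, label\n",
    "\n",
    "toy_ds = ToyDataset([[2, 3, 0, 0], [4, 1, 5, 0], [6, 7, 8, 9]], [1, 0, 1])\n",
    "toy_dl = DataLoader(toy_ds, batch_size=2, shuffle=False)\n",
    "features_demo, labels_demo = next(iter(toy_dl))\n",
    "print(features_demo.shape, labels_demo.shape)  # torch.Size([2, 4]) torch.Size([2, 1])"
   ]
  },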
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "698296a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)\n",
    "dl_val = DataLoader(ds_val,batch_size = 50,shuffle = False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b0e7824",
   "metadata": {},
   "outputs": [],
   "source": [
    "for features,labels in dl_train:\n",
    "    break "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8971bba1",
   "metadata": {},
   "source": [
    "### 2. Define the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d014a534",
   "metadata": {},
   "source": [
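    "There are usually three ways to build a model in PyTorch: stack layers in order with nn.Sequential; subclass the nn.Module base class to define a custom model; or subclass nn.Module and additionally wrap parts of the model in model containers (nn.Sequential, nn.ModuleList, nn.ModuleDict).\n",
    "\n",
    "Here we use the third approach.\n"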
   ]
  },
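  {
   "cell_type": "markdown",
   "id": "sketch-sequential-md",
   "metadata": {},
   "source": [
    "For comparison, the first approach (plain `nn.Sequential`) can express the same layers as the `Net` defined below. This is a sketch, not the notebook's model: `nn.Sequential` has no built-in transpose layer, so a small hypothetical `Transpose` module is added, and the vocabulary size 8813 is assumed from the `Embedding(8813, 3, padding_idx=0)` line printed for `Net` rather than computed here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-sequential-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical helper: nn.Sequential has no transpose layer, so wrap one in a Module\n",
    "class Transpose(nn.Module):\n",
    "    def forward(self, x):\n",
    "        # (batch, seq_len, emb_dim) -> (batch, emb_dim, seq_len), the layout Conv1d expects\n",
    "        return x.transpose(1, 2)\n",
    "\n",
    "SKETCH_VOCAB_SIZE = 8813  # assumed; the notebook itself uses len(vocab)\n",
    "seq_net = nn.Sequential(\n",
    "    nn.Embedding(SKETCH_VOCAB_SIZE, 3, padding_idx=0),\n",
    "    Transpose(),\n",
    "    nn.Conv1d(3, 16, kernel_size=5), nn.MaxPool1d(2), nn.ReLU(),\n",
    "    nn.Conv1d(16, 128, kernel_size=2), nn.MaxPool1d(2), nn.ReLU(),\n",
    "    nn.Flatten(), nn.Linear(6144, 1),\n",
    ")\n",
    "seq_out = seq_net(torch.randint(0, SKETCH_VOCAB_SIZE, (4, 200)))  # 4 sequences of length 200\n",
    "print(seq_out.shape)  # torch.Size([4, 1])"
   ]
  },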
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15b42dc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn \n",
    "torch.manual_seed(42)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37193036",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        \n",
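    "        # With padding_idx set, the padding token's embedding starts as a zero vector and its gradient is always zero, so it stays zero during training\n",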
    "        self.embedding = nn.Embedding(num_embeddings = vocab_size,embedding_dim = 3,padding_idx = 0)\n",
    "        \n",
    "        self.conv = nn.Sequential()\n",
    "        self.conv.add_module(\"conv_1\",nn.Conv1d(in_channels = 3,out_channels = 16,kernel_size = 5))\n",
    "        self.conv.add_module(\"pool_1\",nn.MaxPool1d(kernel_size = 2))\n",
    "        self.conv.add_module(\"relu_1\",nn.ReLU())\n",
    "        self.conv.add_module(\"conv_2\",nn.Conv1d(in_channels = 16,out_channels = 128,kernel_size = 2))\n",
    "        self.conv.add_module(\"pool_2\",nn.MaxPool1d(kernel_size = 2))\n",
    "        self.conv.add_module(\"relu_2\",nn.ReLU())\n",
    "        \n",
    "        self.dense = nn.Sequential()\n",
    "        self.dense.add_module(\"flatten\",nn.Flatten())\n",
    "        self.dense.add_module(\"linear\",nn.Linear(6144,1))\n",
    "        \n",
    "        \n",
    "    def forward(self,x):\n",
    "        x = self.embedding(x).transpose(1,2)\n",
    "        x = self.conv(x)\n",
    "        y = self.dense(x)\n",
    "        return y\n",
    "        \n",
    "net = Net() \n",
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7f34cc2",
   "metadata": {},
   "source": [
    "```\n",
    "Net(\n",
    "  (embedding): Embedding(8813, 3, padding_idx=0)\n",
    "  (conv): Sequential(\n",
    "    (conv_1): Conv1d(3, 16, kernel_size=(5,), stride=(1,))\n",
    "    (pool_1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
    "    (relu_1): ReLU()\n",
    "    (conv_2): Conv1d(16, 128, kernel_size=(2,), stride=(1,))\n",
    "    (pool_2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
    "    (relu_2): ReLU()\n",
    "  )\n",
    "  (dense): Sequential(\n",
    "    (flatten): Flatten(start_dim=1, end_dim=-1)\n",
    "    (linear): Linear(in_features=6144, out_features=1, bias=True)\n",
    "  )\n",
    ")\n",
    "```"
   ]
  },
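  {
   "cell_type": "markdown",
   "id": "sketch-shape-md",
   "metadata": {},
   "source": [
    "The `in_features=6144` of the final `nn.Linear` follows from the sequence length. Assuming inputs of length `MAX_LEN = 200`, convolutions with stride 1 and no padding, and `MaxPool1d`'s default stride (equal to `kernel_size`), each layer outputs floor((L_in - kernel_size) / stride) + 1 positions: 200 -> 196 -> 98 -> 97 -> 48, and 128 channels x 48 positions = 6144. The helper `out_len` below is a hypothetical name used only to check this arithmetic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-shape-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output length of Conv1d / MaxPool1d with padding=0 and dilation=1\n",
    "def out_len(length, kernel_size, stride=1):\n",
    "    return (length - kernel_size) // stride + 1\n",
    "\n",
    "seq_len = 200                     # MAX_LEN\n",
    "seq_len = out_len(seq_len, 5)     # conv_1 -> 196\n",
    "seq_len = out_len(seq_len, 2, 2)  # pool_1 (stride defaults to kernel_size) -> 98\n",
    "seq_len = out_len(seq_len, 2)     # conv_2 -> 97\n",
    "seq_len = out_len(seq_len, 2, 2)  # pool_2 -> 48\n",
    "print(seq_len, 128 * seq_len)     # 48 6144"
   ]
  },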
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab57016b",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "from torchkeras import summary \n",
    "summary(net,input_data=features);\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcc09084",
   "metadata": {},
   "source": [
    "### 3. Train the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d476c85",
   "metadata": {},
   "source": [
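    "Training a model in PyTorch usually requires writing a custom training loop, and the style of that code varies from person to person.\n",
    "\n",
    "There are three typical styles of training loop: script-style, function-style, and class-style.\n",
    "\n",
    "Here we present a fairly general class-style training loop modeled after Keras.\n",
    "\n",
    "This training loop is also the core code of the torchkeras library.\n",
    "\n",
    "torchkeras details: https://github.com/lyhue1991/torchkeras\n",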
    "\n"
   ]
  },
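  {
   "cell_type": "markdown",
   "id": "sketch-loop-md",
   "metadata": {},
   "source": [
    "Of the three styles, the function-style loop is the shortest to write, so here is a minimal sketch of one for comparison with the class-style code that follows. It trains on hypothetical random data (a linearly separable toy problem fitted by `nn.Linear`), not on IMDB, and the names `train_one_epoch`, `toy_net` and so on are illustrative. It calls `torch.manual_seed(0)`, so running it also changes the random state seen by later cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-loop-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "def train_one_epoch(net, loss_fn, optimizer, dataloader):\n",
    "    # one pass over the data: forward, loss, backward, parameter update\n",
    "    net.train()\n",
    "    total_loss = 0.0\n",
    "    for features, labels in dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        loss = loss_fn(net(features), labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "    return total_loss / len(dataloader)  # mean batch loss for the epoch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x_toy = torch.randn(64, 4)\n",
    "y_toy = (x_toy.sum(dim=1, keepdim=True) > 0).float()  # toy binary labels, shape (64, 1)\n",
    "toy_loader = DataLoader(TensorDataset(x_toy, y_toy), batch_size=16)\n",
    "toy_net = nn.Linear(4, 1)\n",
    "toy_opt = torch.optim.Adam(toy_net.parameters(), lr=0.1)\n",
    "epoch_losses = [train_one_epoch(toy_net, nn.BCEWithLogitsLoss(), toy_opt, toy_loader) for _ in range(20)]\n",
    "print(epoch_losses[0], epoch_losses[-1])"
   ]
  },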
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7436987a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import datetime \n",
    "from tqdm import tqdm \n",
    "\n",
    "import torch\n",
    "from torch import nn \n",
    "from copy import deepcopy\n",
    "\n",
    "def printlog(info):\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "    print(str(info)+\"\\n\")\n",
    "\n",
    "class StepRunner:\n",
    "    def __init__(self, net, loss_fn,stage = \"train\", metrics_dict = None, \n",
    "                 optimizer = None, lr_scheduler = None\n",
    "                 ):\n",
    "        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage\n",
    "        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler\n",
    "    \n",
    "    def __call__(self, features, labels):\n",
    "        #loss\n",
    "        preds = self.net(features)\n",
    "        loss = self.loss_fn(preds,labels)\n",
    "\n",
    "        #backward()\n",
    "        if self.optimizer is not None and self.stage==\"train\":\n",
    "            loss.backward()\n",
    "            self.optimizer.step()\n",
    "            if self.lr_scheduler is not None:\n",
    "                self.lr_scheduler.step()\n",
    "            self.optimizer.zero_grad()\n",
    "            \n",
    "        #metrics\n",
    "        step_metrics = {self.stage+\"_\"+name:metric_fn(preds, labels).item() \n",
    "                        for name,metric_fn in self.metrics_dict.items()}\n",
    "        return loss.item(),step_metrics\n",
    "\n",
    "\n",
    "class EpochRunner:\n",
    "    def __init__(self,steprunner):\n",
    "        self.steprunner = steprunner\n",
    "        self.stage = steprunner.stage\n",
    "        self.steprunner.net.train() if self.stage==\"train\" else self.steprunner.net.eval()\n",
    "        \n",
    "    def __call__(self,dataloader):\n",
    "        total_loss,step = 0,0\n",
    "        loop = tqdm(enumerate(dataloader), total =len(dataloader))\n",
    "        for i, batch in loop: \n",
    "            if self.stage==\"train\":\n",
    "                loss, step_metrics = self.steprunner(*batch)\n",
    "            else:\n",
    "                with torch.no_grad():\n",
    "                    loss, step_metrics = self.steprunner(*batch)\n",
    "            step_log = dict({self.stage+\"_loss\":loss},**step_metrics)\n",
    "\n",
    "            total_loss += loss\n",
    "            step+=1\n",
    "            if i!=len(dataloader)-1:\n",
    "                loop.set_postfix(**step_log)\n",
    "            else:\n",
    "                epoch_loss = total_loss/step\n",
    "                epoch_metrics = {self.stage+\"_\"+name:metric_fn.compute().item() \n",
    "                                 for name,metric_fn in self.steprunner.metrics_dict.items()}\n",
    "                epoch_log = dict({self.stage+\"_loss\":epoch_loss},**epoch_metrics)\n",
    "                loop.set_postfix(**epoch_log)\n",
    "\n",
    "                for name,metric_fn in self.steprunner.metrics_dict.items():\n",
    "                    metric_fn.reset()\n",
    "        return epoch_log\n",
    "\n",
    "class KerasModel(torch.nn.Module):\n",
    "    def __init__(self,net,loss_fn,metrics_dict=None,optimizer=None,lr_scheduler = None):\n",
    "        super().__init__()\n",
    "        self.history = {}\n",
    "        \n",
    "        self.net = net\n",
    "        self.loss_fn = loss_fn\n",
    "        self.metrics_dict = nn.ModuleDict(metrics_dict) \n",
    "        \n",
    "        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(\n",
    "            self.parameters(), lr=1e-2)\n",
    "        self.lr_scheduler = lr_scheduler\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.net:\n",
    "            return self.net.forward(x)\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "\n",
    "\n",
    "    def fit(self, train_data, val_data=None, epochs=10, ckpt_path='checkpoint.pt', \n",
    "            patience=5, monitor=\"val_loss\", mode=\"min\"):\n",
    "\n",
    "        for epoch in range(1, epochs+1):\n",
    "            printlog(\"Epoch {0} / {1}\".format(epoch, epochs))\n",
    "            \n",
    "            # 1. train -------------------------------------------------\n",
    "            train_step_runner = StepRunner(net = self.net,stage=\"train\",\n",
    "                    loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),\n",
    "                    optimizer = self.optimizer, lr_scheduler = self.lr_scheduler)\n",
    "            train_epoch_runner = EpochRunner(train_step_runner)\n",
    "            train_metrics = train_epoch_runner(train_data)\n",
    "            \n",
    "            for name, metric in train_metrics.items():\n",
    "                self.history[name] = self.history.get(name, []) + [metric]\n",
    "\n",
    "            # 2. validate -------------------------------------------------\n",
    "            if val_data:\n",
    "                val_step_runner = StepRunner(net = self.net,stage=\"val\",\n",
    "                    loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict))\n",
    "                val_epoch_runner = EpochRunner(val_step_runner)\n",
    "                with torch.no_grad():\n",
    "                    val_metrics = val_epoch_runner(val_data)\n",
    "                val_metrics[\"epoch\"] = epoch\n",
    "                for name, metric in val_metrics.items():\n",
    "                    self.history[name] = self.history.get(name, []) + [metric]\n",
    "            \n",
    "            # 3. early-stopping -------------------------------------------------\n",
    "            if not val_data:\n",
    "                continue\n",
    "            arr_scores = self.history[monitor]\n",
    "            best_score_idx = np.argmax(arr_scores) if mode==\"max\" else np.argmin(arr_scores)\n",
    "            if best_score_idx==len(arr_scores)-1:\n",
    "                torch.save(self.net.state_dict(),ckpt_path)\n",
    "                print(\"<<<<<< reach best {0} : {1} >>>>>>\".format(monitor,\n",
    "                     arr_scores[best_score_idx]),file=sys.stderr)\n",
    "            if len(arr_scores)-best_score_idx>patience:\n",
    "                print(\"<<<<<< {} without improvement in {} epoch, early stopping >>>>>>\".format(\n",
    "                    monitor,patience),file=sys.stderr)\n",
    "                break \n",
    "                \n",
    "        if val_data:  # a checkpoint is only written when validation runs\n",
    "            self.net.load_state_dict(torch.load(ckpt_path))\n",
    "        return pd.DataFrame(self.history)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def evaluate(self, val_data):\n",
    "        val_step_runner = StepRunner(net = self.net,stage=\"val\",\n",
    "                    loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict))\n",
    "        val_epoch_runner = EpochRunner(val_step_runner)\n",
    "        val_metrics = val_epoch_runner(val_data)\n",
    "        return val_metrics\n",
    "        \n",
    "       \n",
    "    @torch.no_grad()\n",
    "    def predict(self, dataloader):\n",
    "        self.net.eval()\n",
    "        result = torch.cat([self.forward(t[0]) for t in dataloader])\n",
    "        return result.data\n"
   ]
  },
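  {
   "cell_type": "markdown",
   "id": "sketch-earlystop-md",
   "metadata": {},
   "source": [
    "The early-stopping rule in step 3 of `fit` can be replayed on its own. The sketch below applies the same condition, `len(arr_scores) - best_score_idx > patience`, to a hypothetical sequence of `val_acc` values with `patience=3` and `mode='max'`. Because `np.argmax` returns the first best epoch, matching the earlier best at epoch 6 does not count as an improvement: training stops there, and the checkpoint stays at epoch 3, since `fit` only saves when the best epoch is the latest one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-earlystop-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Same condition as step 3 of KerasModel.fit: stop once more than `patience` epochs\n",
    "# have passed since the best score (argmax/argmin return the first best epoch)\n",
    "def should_stop_sketch(scores, patience, mode='max'):\n",
    "    best_idx = int(np.argmax(scores)) if mode == 'max' else int(np.argmin(scores))\n",
    "    return len(scores) - best_idx > patience\n",
    "\n",
    "val_acc_demo = [0.80, 0.83, 0.84, 0.83, 0.82, 0.84]  # hypothetical val_acc per epoch\n",
    "stop_flags = [should_stop_sketch(val_acc_demo[:e], patience=3) for e in range(1, len(val_acc_demo) + 1)]\n",
    "print(stop_flags)  # [False, False, False, False, False, True]"
   ]
  },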
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2900cc1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchmetrics \n",
    "\n",
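    "# Assumes an older torchmetrics release (as used alongside torch 1.10); torchmetrics 0.11 and later require a task argument such as task='binary'\n",
    "class Accuracy(torchmetrics.Accuracy):\n",
    "    def __init__(self, dist_sync_on_step=False):\n",
    "        super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
    "        \n",
    "    def update(self, preds: torch.Tensor, targets: torch.Tensor):\n",
    "        # preds are raw logits from the network: convert them to probabilities before thresholding\n",
    "        super().update(torch.sigmoid(preds),targets.long())\n",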
    "            \n",
    "    def compute(self):\n",
    "        return super().compute()\n",
    "    \n",
    "net = Net() \n",
    "model = KerasModel(net,\n",
    "                  loss_fn = nn.BCEWithLogitsLoss(),\n",
    "                  optimizer= torch.optim.Adam(net.parameters(),lr = 0.01),  \n",
    "                  metrics_dict = {\"acc\":Accuracy()}\n",
    "                )\n"
   ]
  },
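  {
   "cell_type": "markdown",
   "id": "sketch-accuracy-md",
   "metadata": {},
   "source": [
    "The `Accuracy` wrapper applies `torch.sigmoid` because the network outputs raw logits (the loss is `BCEWithLogitsLoss`). A hand-written equivalent, assuming a 0.5 threshold on the probabilities, is sketched below; `binary_accuracy_sketch` and the example tensors are hypothetical. Thresholding the sigmoid at 0.5 is the same as checking whether the logit is positive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-accuracy-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def binary_accuracy_sketch(logits, targets, threshold=0.5):\n",
    "    # sigmoid turns logits into probabilities; sigmoid(z) > 0.5 exactly when z > 0\n",
    "    preds = (torch.sigmoid(logits) > threshold).float()\n",
    "    return (preds == targets).float().mean().item()\n",
    "\n",
    "demo_logits = torch.tensor([[2.0], [-1.0], [0.5], [-3.0]])  # hypothetical network outputs\n",
    "demo_targets = torch.tensor([[1.0], [0.0], [0.0], [0.0]])\n",
    "print(binary_accuracy_sketch(demo_logits, demo_targets))  # 0.75"
   ]
  },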
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8a04882",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fit(dl_train,\n",
    "    val_data=dl_val,\n",
    "    epochs=10,\n",
    "    ckpt_path='checkpoint.pt',\n",
    "    patience=3,\n",
    "    monitor='val_acc',\n",
    "    mode='max')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ff35951",
   "metadata": {},
   "source": [
    "### 4. Evaluate the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff9d1407",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "\n",
    "history = model.history\n",
    "dfhistory = pd.DataFrame(history) \n",
    "dfhistory \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93107420",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_metric(dfhistory, metric):\n",
    "    train_metrics = dfhistory[\"train_\"+metric]\n",
    "    val_metrics = dfhistory['val_'+metric]\n",
    "    epochs = range(1, len(train_metrics) + 1)\n",
    "    plt.plot(epochs, train_metrics, 'bo--')\n",
    "    plt.plot(epochs, val_metrics, 'ro-')\n",
    "    plt.title('Training and validation '+ metric)\n",
    "    plt.xlabel(\"Epochs\")\n",
    "    plt.ylabel(metric)\n",
    "    plt.legend([\"train_\"+metric, 'val_'+metric])\n",
    "    plt.show()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51029d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"loss\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18385277",
   "metadata": {},
   "source": [
    "![](./data/1-3-loss曲线.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94294cb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"acc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2960e2b",
   "metadata": {},
   "source": [
    "![](./data/1-3-accuracy曲线.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a148942",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the validation set\n",
    "model.evaluate(dl_val)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2796b76",
   "metadata": {},
   "source": [
    "```\n",
    "{'val_loss': 0.36953783154487607, 'val_acc': 0.848800003528595}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1bc62f0",
   "metadata": {},
   "source": [
    "### 5. Use the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6a67857",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(net,dl):\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        result = nn.Sigmoid()(torch.cat([net.forward(t[0]) for t in dl]))\n",
    "    return(result.data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f916311",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred_probs = predict(net,dl_val)\n",
    "y_pred_probs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9976d04f",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[0.5638],\n",
    "        [0.9990],\n",
    "        [0.9573],\n",
    "        ...,\n",
    "        [0.9188],\n",
    "        [0.8004],\n",
    "        [0.9998]])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31c888a8",
   "metadata": {},
   "source": [
    "### 6. Save the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee13dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The best model weights were already saved to ckpt_path='checkpoint.pt' during model.fit()\n",
    "net_clone = Net()\n",
    "net_clone.load_state_dict(torch.load('checkpoint.pt'))\n",
    "\n"
   ]
  },
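  {
   "cell_type": "markdown",
   "id": "sketch-save-md",
   "metadata": {},
   "source": [
    "The cell above only loads weights. For completeness, here is a self-contained round trip of saving a `state_dict` and loading it into a fresh instance of the same architecture. It uses a small `nn.Linear` and a temporary file (`demo_checkpoint.pt` is a hypothetical name) so it does not touch `checkpoint.pt`; for `Net` the same pattern is `torch.save(net.state_dict(), path)` followed by `Net().load_state_dict(torch.load(path))`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-save-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "src_model = nn.Linear(4, 1)\n",
    "ckpt_demo = os.path.join(tempfile.mkdtemp(), 'demo_checkpoint.pt')  # hypothetical path\n",
    "torch.save(src_model.state_dict(), ckpt_demo)  # save only the parameters\n",
    "\n",
    "dst_model = nn.Linear(4, 1)  # fresh instance of the same architecture\n",
    "dst_model.load_state_dict(torch.load(ckpt_demo, map_location='cpu'))\n",
    "same_weights = all(torch.equal(a, b) for a, b in zip(src_model.state_dict().values(), dst_model.state_dict().values()))\n",
    "print(same_weights)  # True"
   ]
  },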
  {
   "cell_type": "markdown",
   "id": "1bf6c7de",
   "metadata": {},
   "source": [
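    "**If this book has helped you and you would like to encourage the author, please give this project a star ⭐️ and share it with your friends 😊!** \n",
    "\n",
    "If you would like to discuss anything in this book further with the author, feel free to leave a message on the WeChat official account \"算法美食屋\". The author's time and energy are limited, so replies will be given as circumstances allow.\n",
    "\n",
    "You can also reply with the keyword **加群** (join group) in the account's chat to join the reader discussion group.\n",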
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,md",
   "main_language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
